{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label of this experiment: stored as CONF['name'] and used to name TensorBoard run directories and checkpoints.\n",
    "TRIAL_NAME='trial'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Central experiment configuration read by the cells below.\n",
    "CONF = dict(\n",
    "    niter=5,                   # outer iterations of trial(): evaluate the probe, then one pass over train_loader\n",
    "    ntest=50,                  # passes over the cached features when fitting the linear probe\n",
    "    GPU=0,                     # index exported through CUDA_VISIBLE_DEVICES\n",
    "    BS=4,                      # base batch size; train_loader uses BS * (N_neg + 1)\n",
    "    test_BS=256,               # batch size of the classification / validation loaders\n",
    "    N_neg=3,                   # passed to PC_Embedder as k; sizes the training batch together with BS\n",
    "    name=TRIAL_NAME,           # run label used in log and checkpoint names\n",
    "    nz=2048,                   # embedding width given to PC_Embedder and input width of the linear probe\n",
    "    ng=128,                    # not referenced by the cells visible in this notebook\n",
    "    seed=10715,                # seed for random, torch and numpy\n",
    "    data_dir='/DataSet/COCO',  # Set it to where your COCO dataset is\n",
    "    dataType='train2017',      # split used for the caption set and the probe's training images\n",
    "    valType='val2017',         # split used for the probe's validation images\n",
    "    testType='test2017',       # not referenced by the cells visible in this notebook\n",
    "    max_len=512,               # tokenizer max_len and BERT position_size\n",
    "    hidden_size=768,           # BERT embedding / encoder width\n",
    "    bert_hdim=3072,            # width of BERT's position-wise feed-forward layer\n",
    "    LAMBDAs=0.005,             # LAMBDA handed to trial(): weight of the l2 term in PC_Embedder.get_loss\n",
    "    use_super=False,           # CocoClassification uses the 12 supercategory names as labels when True\n",
    "    test_classes=80,           # classes sampled by sample_class() and output width of the linear probe\n",
    "    n_trials=5,                # presumably the number of repeated trials -- TODO confirm\n",
    "    bert_pretrained=True,      # pretrain bert\n",
    "    resnet_pretrained=True,    # pretrain resnet\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expose only the configured GPU to CUDA; this cell runs before the cell that imports torch.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=str(CONF['GPU'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import texar.torch as tx\n",
    "import random\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.parallel\n",
    "from torch.nn import functional as F\n",
    "import torch.backends.cudnn as cudnn\n",
    "import torch.optim as optim\n",
    "import torch.utils.data\n",
    "import torchvision\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "import torchvision.datasets as dset\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.utils as vutils\n",
    "import numpy as np\n",
    "from torch import autograd\n",
    "import multiprocessing\n",
    "from PIL import Image\n",
    "from sklearn import metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if True else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start the run-wide TensorBoard writer from an empty log directory.\n",
    "import shutil\n",
    "\n",
    "# NOTE(review): the pauses around the delete presumably give the filesystem / TensorBoard time to settle -- confirm.\n",
    "time.sleep(2)\n",
    "tb_dir = os.path.join('./runs', '{}_GLOBAL'.format(CONF['name']))\n",
    "shutil.rmtree(tb_dir, ignore_errors=True)  # a missing directory is not an error\n",
    "time.sleep(5)\n",
    "global_writer = SummaryWriter(log_dir=tb_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed every RNG used below (torch, NumPy, Python's random) from the configured seed.\n",
    "torch.manual_seed(CONF['seed'])\n",
    "np.random.seed(CONF['seed'])\n",
    "random.seed(CONF['seed'])\n",
    "# Let cuDNN benchmark and pick its convolution algorithms.\n",
    "# NOTE(review): benchmark mode may pick different kernels between runs, so results are not\n",
    "# guaranteed to be bit-for-bit reproducible despite the seeds -- confirm this is acceptable.\n",
    "cudnn.benchmark = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base batch size; the caption loader multiplies it by (CONF['N_neg'] + 1).\n",
    "batch_size = CONF['BS']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Augmenting pipeline applied to the caption dataset images.\n",
    "T = transforms.Compose([\n",
    "    transforms.RandomResizedCrop(\n",
    "        size=(224, 224),\n",
    "        scale=(0.3, 1.0),\n",
    "        ratio=(3 / 4, 4 / 3),\n",
    "    ),\n",
    "    transforms.ColorJitter(\n",
    "        brightness=0.1,\n",
    "        contrast=0.05,\n",
    "        saturation=0.05,\n",
    "        hue=0.05,\n",
    "    ),\n",
    "    transforms.RandomHorizontalFlip(p=0.5),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n",
    "])\n",
    "\n",
    "# Deterministic pipeline applied to the classification datasets: plain resize, same normalisation.\n",
    "T_test = transforms.Compose([\n",
    "    transforms.Resize(size=(224, 224)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.datasets.vision import VisionDataset\n",
    "class CocoClassification(VisionDataset):\n",
    "    \"\"\"Multi-label classification view of an MS COCO instances annotation file.\n",
    "\n",
    "    Every item is ``(image, targets)``; ``targets`` is a float multi-hot vector with one\n",
    "    entry per element of ``self.classes`` (category ids, or supercategory names when\n",
    "    ``CONF['use_super']`` is set).\n",
    "\n",
    "    Args:\n",
    "        root (string): Root directory where images are downloaded to.\n",
    "        annFile (string): Path to json annotation file.\n",
    "        transform (callable, optional): A function/transform that takes in a PIL image\n",
    "            and returns a transformed version. E.g, ``transforms.ToTensor``\n",
    "        target_transform (callable, optional): A function/transform that takes in the\n",
    "            target and transforms it.\n",
    "        transforms (callable, optional): A function/transform that takes input sample and its target as entry\n",
    "            and returns a transformed version.\n",
    "    \"\"\"\n",
    "\n",
    "    def sample_class(self, k):\n",
    "        \"\"\"Choose the label set and the image ids the dataset iterates over.\n",
    "\n",
    "        With ``CONF['use_super']`` the labels are the 12 fixed supercategory names and ``k``\n",
    "        is ignored. Otherwise ``k`` category ids are drawn without replacement (NumPy global\n",
    "        RNG) from the annotation file and kept sorted.\n",
    "\n",
    "        In both modes ``self.ids`` is set to the sorted ids of all images containing at least\n",
    "        one instance of a category in play, so ``__len__`` / ``__getitem__`` work either way.\n",
    "\n",
    "        Args:\n",
    "            k (int): Number of categories to sample (unused in supercategory mode).\n",
    "        \"\"\"\n",
    "        if CONF['use_super']:\n",
    "            super_names = [\"vehicle\", \"outdoor\", \"indoor\", \"person\", \"appliance\", \"furniture\", \"sports\", \"food\", \"kitchen\", \"accessory\", \"electronic\", \"animal\"]\n",
    "            self.classes = np.array(super_names)\n",
    "            self.class_description = list(super_names)\n",
    "            # Every category belongs to some supercategory, so all of them contribute images.\n",
    "            # NOTE(review): this mode yields 12 labels; the linear probe must be sized to match.\n",
    "            active_cat_ids = self.coco.getCatIds()\n",
    "        else:\n",
    "            class_list = self.coco.getCatIds()\n",
    "            self.classes = np.sort(np.random.choice(class_list, size=k, replace=False))\n",
    "            self.class_description = self.coco.loadCats(self.classes)\n",
    "            active_cat_ids = self.classes\n",
    "        # Union of the images of every active category, de-duplicated and ordered.\n",
    "        img_ids = set()\n",
    "        for cat_id in active_cat_ids:\n",
    "            img_ids.update(self.coco.getImgIds(catIds=[cat_id]))\n",
    "        self.ids = sorted(img_ids)\n",
    "\n",
    "    def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):\n",
    "        super(CocoClassification, self).__init__(root, transforms, transform, target_transform)\n",
    "        # Imported lazily so the notebook only needs pycocotools when this dataset is built.\n",
    "        from pycocotools.coco import COCO\n",
    "        self.coco = COCO(annFile)\n",
    "        # Start with every category of the annotation file selected.\n",
    "        self.sample_class(len(self.coco.getCatIds()))\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            index (int): Index\n",
    "\n",
    "        Returns:\n",
    "            tuple: Tuple (image, targets). targets is a FloatTensor with a 1 for every entry of\n",
    "            ``self.classes`` that is annotated in the image and 0 elsewhere.\n",
    "        \"\"\"\n",
    "        coco = self.coco\n",
    "        img_id = self.ids[index]\n",
    "        anns = coco.loadAnns(coco.getAnnIds(imgIds=img_id))\n",
    "        cats = coco.loadCats([ann['category_id'] for ann in anns])\n",
    "        # Labels are supercategory names or category ids depending on the mode.\n",
    "        label_key = 'supercategory' if CONF['use_super'] else 'id'\n",
    "        present = {cat[label_key] for cat in cats}\n",
    "        targets = torch.FloatTensor([1 if (c in present) else 0 for c in self.classes])\n",
    "        path = coco.loadImgs(img_id)[0]['file_name']\n",
    "        # convert() returns a fully loaded copy, so the file handle can be closed right away.\n",
    "        with Image.open(os.path.join(self.root, path)) as handle:\n",
    "            img = handle.convert('RGB')\n",
    "        if self.transforms is not None:\n",
    "            img, targets = self.transforms(img, targets)\n",
    "\n",
    "        return img, targets\n",
    "\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.ids)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training images paired with their captions (augmented with T).\n",
    "dataset = dset.CocoCaptions(\n",
    "    root='%s/%s' % (CONF['data_dir'], CONF['dataType']),\n",
    "    annFile='%s/annotations/captions_%s.json' % (CONF['data_dir'], CONF['dataType']),\n",
    "    transform=T,\n",
    ")\n",
    "\n",
    "# Same training images, viewed as a multi-label classification set (no augmentation).\n",
    "clas_set = CocoClassification(\n",
    "    root='%s/%s' % (CONF['data_dir'], CONF['dataType']),\n",
    "    annFile='%s/annotations/instances_%s.json' % (CONF['data_dir'], CONF['dataType']),\n",
    "    transform=T_test,\n",
    ")\n",
    "\n",
    "# Validation split of the classification set.\n",
    "val_set = CocoClassification(\n",
    "    root='%s/%s' % (CONF['data_dir'], CONF['valType']),\n",
    "    annFile='%s/annotations/instances_%s.json' % (CONF['data_dir'], CONF['valType']),\n",
    "    transform=T_test,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Caption loader: each batch holds batch_size groups of (N_neg + 1) image/caption pairs;\n",
    "# incomplete final batches are dropped so the grouping in PC_Embedder.get_loss always divides evenly.\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    dataset,\n",
    "    batch_size=batch_size * (CONF['N_neg'] + 1),\n",
    "    shuffle=True,\n",
    "    num_workers=2,\n",
    "    pin_memory=True,\n",
    "    drop_last=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validation loader for the linear probe: fixed order, large batches.\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    val_set,\n",
    "    batch_size=CONF['test_BS'],\n",
    "    shuffle=False,\n",
    "    num_workers=8,\n",
    "    pin_memory=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters handed to the texar BERT tokenizer.\n",
    "hparams = dict(\n",
    "    pretrained_model_name=\"bert-base-uncased\",\n",
    "    vocab_file=None,\n",
    "    max_len=CONF['max_len'],\n",
    "    # special tokens\n",
    "    unk_token=\"[UNK]\",\n",
    "    sep_token=\"[SEP]\",\n",
    "    pad_token=\"[PAD]\",\n",
    "    cls_token=\"[CLS]\",\n",
    "    mask_token=\"[MASK]\",\n",
    "    # tokenisation switches\n",
    "    tokenize_chinese_chars=True,\n",
    "    do_lower_case=True,\n",
    "    do_basic_tokenize=True,\n",
    "    non_split_tokens=None,\n",
    "    name=\"bert_tokenizer\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# BERT tokenizer built from the hparams above; trial() calls its encode_text() on every sampled caption.\n",
    "tokenizer = tx.data.BERTTokenizer(hparams=hparams, pretrained_model_name='bert-base-uncased')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters for the texar BERTEncoder created inside PC_Embedder.\n",
    "# Widths and the position table size are taken from CONF so they stay in sync with the tokenizer.\n",
    "bert_hparams = {\n",
    " 'embed': {'dim': CONF['hidden_size'], 'name': 'word_embeddings'},\n",
    " 'vocab_size': 30522,\n",
    " 'segment_embed': {'dim': CONF['hidden_size'], 'name': 'token_type_embeddings'},\n",
    " 'type_vocab_size': 2,\n",
    " 'position_embed': {'dim': CONF['hidden_size'], 'name': 'position_embeddings'},\n",
    " 'position_size': CONF['max_len'],\n",
    " # NOTE(review): 6 heads / 4 blocks below differ from the published bert-base-uncased layout\n",
    " # (12 heads / 12 layers); check how texar reconciles them when pretrained weights are requested -- TODO confirm.\n",
    " 'encoder': {'dim': CONF['hidden_size'],\n",
    "  'embedding_dropout': 0.1,\n",
    "  'multihead_attention': {'dropout_rate': 0.1,\n",
    "   'name': 'self',\n",
    "   'num_heads': 6,\n",
    "   'num_units': CONF['hidden_size'],\n",
    "   'output_dim': CONF['hidden_size'],\n",
    "   'use_bias': True},\n",
    "  'name': 'encoder',\n",
    "  'num_blocks': 4,\n",
    "  # position-wise feed-forward: hidden_size -> bert_hdim -> hidden_size with a GELU in between\n",
    "  'poswise_feedforward': {'layers': [{'kwargs': {'in_features': CONF['hidden_size'],\n",
    "      'out_features': CONF['bert_hdim'],\n",
    "      'bias': True},\n",
    "     'type': 'Linear'},\n",
    "    {'type': 'BertGELU'},\n",
    "    {'kwargs': {'in_features': CONF['bert_hdim'], 'out_features': CONF['hidden_size'], 'bias': True},\n",
    "     'type': 'Linear'}]},\n",
    "  'residual_dropout': 0.1,\n",
    "  'use_bert_config': True},\n",
    " 'hidden_size': CONF['hidden_size'],\n",
    " 'initializer': None,\n",
    " 'name': 'bert_encoder',\n",
    " 'pretrained_model_name':None}\n",
    "\n",
    "# Request the pretrained checkpoint only when the configuration asks for it.\n",
    "if CONF['bert_pretrained']:\n",
    "    bert_hparams[\"pretrained_model_name\"]=\"bert-base-uncased\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Flatten(nn.Module):\n",
    "    def forward(self, x):\n",
    "        x = x.view(x.size()[0], -1)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PC_Embedder(nn.Module):\n",
    "    \"\"\"Image encoder trained against BERT caption embeddings with a contrastive loss.\n",
    "\n",
    "    Submodules:\n",
    "        bert: texar BERT encoder for the captions.\n",
    "        q:    linear map from BERT's pooled output (CONF['hidden_size']) to ``nz``.\n",
    "        p:    ResNet-50 with its last child module removed, followed by ``Flatten``.\n",
    "        G:    two-layer MLP (2048 -> 512 -> nz) applied on top of ``p``.\n",
    "\n",
    "    Args:\n",
    "        k (int): stored as ``self.k = k + 1``, the size of one group inside a batch.\n",
    "        nz (int): width of the shared embedding space.\n",
    "        hparams (dict): hyperparameters forwarded to ``tx.modules.BERTEncoder``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, k, nz, hparams):\n",
    "        super(PC_Embedder, self).__init__()\n",
    "        # Group size used to reshape batches in get_loss.\n",
    "        self.k = k + 1\n",
    "        # Text branch.\n",
    "        self.bert = tx.modules.BERTEncoder(hparams=hparams)\n",
    "        self.q = nn.Linear(CONF['hidden_size'], nz)\n",
    "        # Image branch: ResNet-50 children minus the last one, flattened to a vector.\n",
    "        backbone = torchvision.models.resnet50(pretrained=CONF['resnet_pretrained'])\n",
    "        trunk = list(backbone.children())[:-1] + [Flatten()]\n",
    "        self.p = nn.Sequential(*trunk)\n",
    "        # Projection head from trunk features to the embedding space.\n",
    "        self.G = nn.Sequential(\n",
    "            nn.Linear(2048, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(512, nz),\n",
    "        )\n",
    "\n",
    "    def forward(self, v):\n",
    "        \"\"\"Return the trunk features ``self.p(v)`` (the projection head ``G`` is not applied).\"\"\"\n",
    "        features = self.p(v)\n",
    "        return features\n",
    "\n",
    "    def get_loss(self, v, v1, LAMBDA):\n",
    "        \"\"\"Contrastive loss between the first image of each group and the group's captions.\n",
    "\n",
    "        Args:\n",
    "            v: image batch whose first dimension is a multiple of ``self.k``.\n",
    "            v1: tuple ``(inputs, segment_ids)`` passed to the BERT encoder.\n",
    "            LAMBDA: weight of the squared-distance term.\n",
    "\n",
    "        Returns:\n",
    "            tuple: ``(mean(-l1 + LAMBDA * l2), l1, l2)`` where ``l1`` is the log-softmax score of\n",
    "            caption 0 of every group and ``l2`` the squared distance between the image\n",
    "            embedding and that caption's embedding.\n",
    "        \"\"\"\n",
    "        token_ids, segment_ids = v1\n",
    "        _, pooled = self.bert(inputs=token_ids, segment_ids=segment_ids)\n",
    "        n_groups = v.shape[0] // self.k\n",
    "\n",
    "        # Embed every image, then keep only the first image of each group.\n",
    "        img_emb = self.G(self.p(v))\n",
    "        width = img_emb.shape[1]\n",
    "        anchor = img_emb.squeeze().view(n_groups, self.k, width)[:, 0, :]\n",
    "\n",
    "        # Repeat the anchor so it lines up with the k caption embeddings of its group.\n",
    "        anchor_rep = anchor.unsqueeze(1).expand(anchor.shape[0], self.k, anchor.shape[1]).contiguous()\n",
    "        txt_emb = self.q(pooled).view(anchor_rep.shape).contiguous()\n",
    "\n",
    "        scores = torch.sum(anchor_rep * txt_emb, dim=-1)\n",
    "        l1 = F.log_softmax(scores, dim=1)[:, 0]\n",
    "        l2 = torch.sum((anchor - txt_emb[:, 0, :]) ** 2, dim=-1)\n",
    "        total = torch.mean(-l1 + LAMBDA * l2, dim=0)\n",
    "        return total, l1, l2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Tracker(object):\n",
    "\n",
    "    def __init__(self, VARS, ranks):\n",
    "        self.var_dict = dict(zip(VARS, [(1000000.0 if ranks[i] else -1000000.0, int(ranks[i])) for i in range(len(VARS))]))\n",
    "        \n",
    "    def update(self, d):\n",
    "        for k,v in d.items():\n",
    "            o,r = self.var_dict[k]\n",
    "            self.var_dict[k] = (np.minimum(v,o) if r else np.maximum(v,o), r)\n",
    "    \n",
    "    def return_dict(self):\n",
    "        D = dict()\n",
    "        for k,(v,_) in self.var_dict.items():\n",
    "            D[k]=v\n",
    "        return D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Linear probe that trial() re-initialises and fits on frozen encoder features.\n",
    "# It is placed on the notebook-wide `device`, the same one the encoder is moved to.\n",
    "# NOTE(review): PC_Embedder.forward returns the trunk features (self.p), not the nz-wide projection,\n",
    "# so this input width only matches while CONF['nz'] equals the trunk's feature width -- confirm.\n",
    "test_class = nn.Linear(CONF['nz'],CONF['test_classes']).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss used in trial() to fit the linear probe on the multi-hot label vectors (takes raw logits).\n",
    "criterion = nn.BCEWithLogitsLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebook progress bars; short aliases used by trial().\n",
    "from tqdm import notebook\n",
    "tnrange=notebook.tnrange\n",
    "tqdm_notebook = notebook.tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def hamming_score(y_true, y_pred, normalize=True, sample_weight=None):\n",
    "    '''\n",
    "    Compute the Hamming score (a.k.a. label-based accuracy) for the multi-label case\n",
    "    https://stackoverflow.com/q/32239577/395857\n",
    "    '''\n",
    "    acc_list = []\n",
    "    for i in range(y_true.shape[0]):\n",
    "        set_true = set( np.where(y_true[i])[0] )\n",
    "        set_pred = set( np.where(y_pred[i])[0] )\n",
    "        tmp_a = None\n",
    "        if len(set_true) == 0 and len(set_pred) == 0:\n",
    "            tmp_a = 1\n",
    "        else:\n",
    "            tmp_a = len(set_true.intersection(set_pred))/\\\n",
    "                    float( len(set_true.union(set_pred)) )\n",
    "        acc_list.append(tmp_a)\n",
    "    return np.mean(acc_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "METRICS = ['train_avg', 'val_avg', 'subset_acc', 'hamming_loss', 'hamming_score',\n",
    "           'micro_f1', 'macro_f1', 'micro_roc_auc', 'macro_roc_auc',\n",
    "           'micro_precision', 'macro_precision', 'micro_recall', 'macro_recall']\n",
    "RANKS = [0, 0, 0, 1, 0,\n",
    "         0, 0, 0, 0,\n",
    "         0, 0, 1, 1]\n",
    "\n",
    "assert(len(METRICS)==len(RANKS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def trial(enc, opt, LAMBDA, trial_number):\n",
    "    \"\"\"Alternate linear-probe evaluation and contrastive training of ``enc``.\n",
    "\n",
    "    Every one of the CONF['niter'] iterations:\n",
    "      1. caches ``enc`` features of ``clas_set`` (eval mode, no gradients);\n",
    "      2. re-initialises ``test_class`` and fits it on those features for CONF['ntest'] passes;\n",
    "      3. scores the probe on the cached features and on ``val_loader`` (encoder still in eval\n",
    "         mode) and logs the metrics to TensorBoard and to a ``Tracker``;\n",
    "      4. trains ``enc`` for one pass over ``train_loader`` with ``enc.get_loss``;\n",
    "      5. saves ``enc.state_dict()`` under ./models.\n",
    "\n",
    "    Args:\n",
    "        enc: module whose call returns image features and which implements ``get_loss``.\n",
    "        opt: optimizer stepping ``enc``'s parameters.\n",
    "        LAMBDA: weight forwarded to ``enc.get_loss``.\n",
    "        trial_number: index used in the TensorBoard directory and checkpoint names.\n",
    "\n",
    "    Returns:\n",
    "        dict: per-metric value kept by ``Tracker`` over all iterations.\n",
    "    \"\"\"\n",
    "    import shutil\n",
    "    c = 0  # step counter of the per-batch loss curves, shared by all iterations\n",
    "    # Clear any logs from previous runs\n",
    "    tb_dir = os.path.join('./runs', CONF['name'] + '_LAMBDA_{}_{}'.format(LAMBDA, trial_number))\n",
    "    shutil.rmtree(tb_dir, ignore_errors=True)\n",
    "    time.sleep(5)\n",
    "    # torch.save() does not create missing directories; make sure the checkpoint folder exists\n",
    "    # before spending an epoch of training.\n",
    "    os.makedirs('./models', exist_ok=True)\n",
    "    tracker = Tracker(METRICS, RANKS)\n",
    "    writer = SummaryWriter(log_dir=tb_dir)\n",
    "    try:\n",
    "        for it in tnrange(CONF['niter'], desc='training with LAMBDA:{}'.format(LAMBDA)):\n",
    "\n",
    "            # ---- 1. cache encoder features for the probe's training set ----\n",
    "            test_class.reset_parameters()\n",
    "            # NOTE(review): val_set keeps the classes chosen at construction; its labels only line up\n",
    "            # with clas_set while CONF['test_classes'] covers every category -- confirm.\n",
    "            clas_set.sample_class(CONF['test_classes'])\n",
    "            clas_loader = torch.utils.data.DataLoader(clas_set, batch_size=CONF['test_BS'], shuffle=True, num_workers=8, pin_memory=False)\n",
    "            # The encoder stays in eval mode for the whole evaluation phase (steps 1-3) so that\n",
    "            # training and validation features are produced under the same conditions.\n",
    "            enc.eval()\n",
    "            test = []\n",
    "            with torch.no_grad():\n",
    "                for img, lbl in tqdm_notebook(clas_loader, leave=False, desc='generating train set'):\n",
    "                    test.append((enc(img.to(device)).cpu(), lbl))\n",
    "\n",
    "            # ---- 2. fit the linear probe on the cached features ----\n",
    "            opt_test = optim.Adam(test_class.parameters(), lr=1e-2)\n",
    "            for _ in tnrange(CONF['ntest'], leave=False, desc='training linear classifier'):\n",
    "                random.shuffle(test)\n",
    "                for z, lbl in test:\n",
    "                    pred = test_class(z.to(device))\n",
    "                    loss = criterion(pred, lbl.to(device))\n",
    "                    opt_test.zero_grad()\n",
    "                    loss.backward()\n",
    "                    opt_test.step()\n",
    "\n",
    "            # ---- 3. score the probe ----\n",
    "            with torch.no_grad():\n",
    "                M = dict()\n",
    "\n",
    "                # Per-class accuracy on the cached training features.\n",
    "                count = 0\n",
    "                corrects = torch.zeros(CONF['test_classes'])\n",
    "                for z, lbl in test:\n",
    "                    count += z.shape[0]\n",
    "                    pred = torch.sigmoid(test_class(z.to(device))) > .5\n",
    "                    corrects += torch.sum(torch.eq(pred.cpu(), lbl), dim=0)\n",
    "                acc = corrects / float(count)\n",
    "                writer.add_scalar('acc/train_avg', torch.mean(acc).item(), global_step=it)\n",
    "                writer.add_histogram('acc_train', acc.cpu().numpy(), global_step=it)\n",
    "                M['train_avg'] = torch.mean(acc).item()\n",
    "\n",
    "                # Per-class accuracy and multi-label metrics on the validation set.\n",
    "                count = 0\n",
    "                corrects = torch.zeros(CONF['test_classes'])\n",
    "                LBL, Y, SCORE = [], [], []\n",
    "                for img, lbl in val_loader:\n",
    "                    LBL.append(lbl.numpy())\n",
    "                    count += img.shape[0]\n",
    "                    prob = torch.sigmoid(test_class(enc(img.to(device))))\n",
    "                    pred = prob > .5\n",
    "                    SCORE.append(prob.cpu().numpy())\n",
    "                    Y.append(pred.cpu().numpy())\n",
    "                    corrects += torch.sum(torch.eq(pred.cpu(), lbl), dim=0)\n",
    "                acc = corrects / float(count)\n",
    "                writer.add_scalar('acc/val_avg', torch.mean(acc).item(), global_step=it)\n",
    "                writer.add_histogram('acc_val', acc.cpu().numpy(), global_step=it)\n",
    "\n",
    "                M['val_avg'] = torch.mean(acc).item()\n",
    "                Y = np.concatenate(Y, axis=0)\n",
    "                LBL = np.concatenate(LBL, axis=0)\n",
    "                SCORE = np.concatenate(SCORE, axis=0)\n",
    "                M['subset_acc'] = metrics.accuracy_score(LBL, Y, normalize=True)\n",
    "                M['hamming_loss'] = metrics.hamming_loss(LBL, Y)\n",
    "                M['hamming_score'] = hamming_score(LBL, Y)\n",
    "                M['micro_f1'] = metrics.f1_score(LBL, Y, average='micro')\n",
    "                M['macro_f1'] = metrics.f1_score(LBL, Y, average='macro')\n",
    "                # ROC-AUC ranks samples by score, so it gets the probabilities, not the 0/1 decisions.\n",
    "                M['micro_roc_auc'] = metrics.roc_auc_score(LBL, SCORE, average='micro')\n",
    "                M['macro_roc_auc'] = metrics.roc_auc_score(LBL, SCORE, average='macro')\n",
    "                M['micro_precision'] = metrics.precision_score(LBL, Y, average='micro')\n",
    "                M['macro_precision'] = metrics.precision_score(LBL, Y, average='macro')\n",
    "                M['micro_recall'] = metrics.recall_score(LBL, Y, average='micro')\n",
    "                M['macro_recall'] = metrics.recall_score(LBL, Y, average='macro')\n",
    "\n",
    "                for k, v in M.items():\n",
    "                    writer.add_scalar('metrics/{}'.format(k), v, global_step=it)\n",
    "                tracker.update(M)\n",
    "                writer.flush()\n",
    "\n",
    "            # ---- 4. one contrastive pass over the caption loader ----\n",
    "            enc.train()\n",
    "            for img, ann in tqdm_notebook(train_loader, leave=False):\n",
    "                # ann is indexed [caption, sample]; draw one caption index per sample.\n",
    "                # randint's upper bound is exclusive, so len(ann) makes every caption eligible.\n",
    "                captions = np.array(ann)\n",
    "                pick = np.random.randint(0, len(captions), size=(1, img.shape[0]))\n",
    "                captions = np.take_along_axis(captions, pick, 0)[0]\n",
    "                inputs, segment_ids = [], []\n",
    "                for s in captions:\n",
    "                    # NOTE(review): the third value returned by encode_text is discarded and no\n",
    "                    # sequence length is passed on to the encoder -- confirm padding is handled as intended.\n",
    "                    x, y, _ = tokenizer.encode_text(s)\n",
    "                    inputs.append(x)\n",
    "                    segment_ids.append(y)\n",
    "\n",
    "                inputs = torch.LongTensor(inputs).to(device)\n",
    "                segment_ids = torch.LongTensor(segment_ids).to(device)\n",
    "                img = img.to(device)\n",
    "                loss, l1, l2 = enc.get_loss(img, (inputs, segment_ids), LAMBDA)\n",
    "                opt.zero_grad()\n",
    "                loss.backward()\n",
    "                opt.step()\n",
    "                writer.add_scalar('loss/loss', loss.item(), global_step=c)\n",
    "                writer.add_scalars('loss/parts', {'l1': -l1.detach().mean().item(), 'l2': l2.detach().mean().item()}, global_step=c)\n",
    "                c += 1\n",
    "\n",
    "            # ---- 5. checkpoint ----\n",
    "            torch.save(enc.state_dict(), './models/{}.pth'.format(CONF['name'] + '_trial_{}_LAMBDA_{}_it_{}'.format(trial_number, LAMBDA, it)))\n",
    "    finally:\n",
    "        # Flush and release the event file even when an iteration fails.\n",
    "        writer.close()\n",
    "    return tracker.return_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run CONF['n_trials'] trials, each with a freshly initialised encoder and optimiser, and\n",
    "# record every trial's tracked metrics as one hparams entry of the run-wide writer.\n",
    "for t in range(CONF['n_trials']):\n",
    "    enc = PC_Embedder(CONF['N_neg'], CONF['nz'], bert_hparams)\n",
    "    enc.to(device)\n",
    "    opt = optim.Adam(enc.parameters(), lr=1e-4)\n",
    "    best_metrics = trial(enc, opt, CONF['LAMBDAs'], t)\n",
    "    global_writer.add_hparams(hparam_dict={'lambda': CONF['LAMBDAs'], 'trial':t}, metric_dict=best_metrics)\n",
    "global_writer.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
